from Network.network import Network
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import copy
from Network.network_utils import reduce_function, get_acti, pytorch_model
from Network.General.Flat.mlp import MLPNetwork
from Network.General.Conv.conv import ConvNetwork
from Network.General.Factor.factor_utils import final_conv_args, final_mlp_args
from Network.General.Factor.factored import return_values

def merge_key_queries(key, query, mask, append_keys=True, append_broadcast_mask = 0, append_mask=False, append_zero_keys=False):
    """Merge a single key with a set of queries into a channels-first tensor.

    Shapes (as used by the callers in this module):
        key:   [batch, 1, key_dim]       (only a single key is supported)
        query: [batch, num_queries, query_dim]
        mask:  [batch, 1, num_queries] or None

    Args:
        append_keys: prepend the key (broadcast to every query) to each query's features.
        append_broadcast_mask: if > 0, append this many channels holding ``1 - mask``
            to every query, so masked-out queries are marked with ones. With no mask,
            these channels are zeros (nothing masked out). This keeps the channel count
            independent of whether a mask is supplied.
        append_mask: concatenate the mask onto the key before broadcasting. It is only
            applied when keys are prepended and a mask is given.
        append_zero_keys: prepend zeros shaped like the key instead of the key itself.
            This leaves room to append keys later. ``append_keys`` takes precedence.

    Returns:
        Tensor of shape [batch, channels, num_queries]. Channels are, in order: the
        (optional) key features, the query features, and the (optional) broadcast mask.
        Queries are multiplied by their mask value when a mask is given.
    """
    # TODO: pairnet key query would be handled here
    n_batch, n_queries = query.shape[0], query.shape[1]

    if mask is not None:
        # mask out the queries: [batch, num_queries, 1] multiplies every query feature
        query = query * mask[:, 0].unsqueeze(-1)

    if append_broadcast_mask > 0:
        if mask is not None:
            # reverse the mask direction so masked-out queries are ones
            inverted = 1 - mask[:, 0].to(query.dtype)
        else:
            # no mask: nothing is masked out
            inverted = query.new_zeros(n_batch, n_queries)
        broadcast_mask = inverted.unsqueeze(-1).broadcast_to(n_batch, n_queries, append_broadcast_mask)
        query = torch.cat([query, broadcast_mask], dim=-1)

    if append_keys or append_zero_keys:
        # only valid if there is a single key: [batch, 1, key_dim]
        if append_mask and mask is not None:
            # assumes both mask and key are single key [batch, 1, dim]
            key = torch.cat([key, mask], dim=-1)
        if append_keys:
            keys = key.broadcast_to(key.shape[0], n_queries, key.shape[-1])
        else:
            # new_zeros inherits device/dtype (get_device() returns -1 on CPU, an invalid device)
            keys = key.new_zeros(key.shape[0], n_queries, key.shape[-1])
        query = torch.cat([keys, query], dim=-1)

    return query.transpose(-1, -2)

class PairNetwork(Network):
    """Single-key pair network over a set of queries.

    Each layer runs a ConvNetwork on the tensor built by ``merge_key_queries``, which is
    channels-first, ``[batch, channels, n_queries]``. Between layers:
    - the conv output is collapsed with ``reduce_function``;
    - the result is merged back with the original queries as the new single key.

    After the last layer, the output is handled in one of two ways:
    - with ``aggregate_final``, it is reduced, flattened and (optionally) decoded by an MLP;
    - otherwise it is (optionally) decoded per query by a ConvNetwork and flattened.

    Only a single key is supported (see keypair for multi-key networks).
    """
    def __init__(self, args):
        """Build the pair conv stack and optional decoder from ``args``.

        Reads:
        - ``args.factor``;
        - ``args.factor_net``: no_decode, reduce_function, num_pair_layers,
          repeat_layers, append_keys, append_zero_keys, append_mask, append_broadcast_mask;
        - ``args.embed_dim``, ``args.output_dim``, ``args.object_dim``,
          ``args.activation_final`` and ``args.aggregate_final``.

        Side effect: sets ``args.factor.final_embed_dim`` on the caller's ``args``.
        This happens before ``final_mlp_args`` / ``final_conv_args`` are called, which
        presumably read it (TODO confirm).
        """
        super().__init__(args)
        self.fp = args.factor
        self.no_decode = args.factor_net.no_decode
        self.reduce_function = args.factor_net.reduce_function
        self.num_layers = args.factor_net.num_pair_layers
        self.repeat_layers = args.factor_net.repeat_layers
        self.embed_dim = args.embed_dim
        # output width of the final conv layer: embed_dim when a decoder follows, else the network output_dim
        self.last_dim = self.embed_dim if (self.embed_dim > 0 and not self.no_decode) else args.output_dim
        self.append_keys = args.factor_net.append_keys
        # NOTE(review): stored but never passed to merge_key_queries in forward
        self.append_zero_keys = args.factor_net.append_zero_keys
        self.append_mask =  args.factor_net.append_mask
        self.append_broadcast_mask = args.factor_net.append_broadcast_mask
        layers = list()

        # pairnets assume keys/queries are already embedded using key_query
        # args.factor.embed_dim is the embedded dimension
        # initialize the internal layers of the pointnet
        self.conv_layers = list()
        for i in range(self.num_layers):
            # each layer gets its own copy of args with the per-layer dims overwritten
            conv_args = copy.deepcopy(args)
            # TODO: make dependent on key and query dims if using feature masks
            # Input channels when embedded:
            # - with append_keys: key + query embeddings (embed_dim each), plus the mask
            #   concatenated onto the key when append_mask (its length is counted as
            #   fp.num_objects, presumably == n_queries);
            # - otherwise: the query embedding only.
            # NOTE(review): zero-key padding (append_zero_keys) is not counted here.
            kq_emb_dim = self.embed_dim * 2 + (int(self.append_mask) * self.fp.num_objects) if self.append_keys else self.embed_dim
            # without embeddings, fall back to args.object_dim (presumably already covers key+query when append_keys -- TODO confirm)
            conv_args.object_dim = kq_emb_dim if self.embed_dim > 0 else args.object_dim + (int(self.append_mask) * self.fp.num_objects)
            # extra channels for the inverted mask that merge_key_queries appends to each query
            conv_args.object_dim += self.append_broadcast_mask
            # intermediate layers emit embed_dim; the last emits embed_dim only if a decoder follows, else output_dim
            conv_args.output_dim = self.embed_dim if (self.embed_dim > 0 and not self.no_decode) or (i < self.num_layers - 1) else args.output_dim
            # with embeddings every conv layer (including the last) uses the hidden activation; otherwise the configured final activation
            conv_args.activation_final = conv_args.activation if self.embed_dim > 0 else args.activation_final
            # only the last iteration's conv_args remains on self
            self.conv_args = conv_args
            # print (self.layer_conv_dim, self.hs[-1], args.num_outputs, self.conv_object_dim)
            # with repeat_layers only layer 0 is built (forward reuses it), plus a separate
            # last layer when its output width differs from embed_dim
            if (not self.repeat_layers) or (i == self.num_layers - 1 and self.last_dim != self.embed_dim) or i == 0:
                self.conv_layers.append(ConvNetwork(conv_args))
        self.conv_layers = nn.ModuleList(self.conv_layers)
        layers.append(self.conv_layers)

        # mutates the caller's args.factor (same object as self.fp)
        args.factor.final_embed_dim = self.embed_dim if self.embed_dim > 0 else args.factor.key_dim + args.factor.query_dim
        self.aggregate_final = args.aggregate_final
        # self.softmax = nn.Softmax(-1)
        # NOTE(review): this MLP decoder is built whenever aggregate_final and not no_decode,
        # but forward only applies it when embed_dim > 0
        if args.aggregate_final and not self.no_decode: # does not work with a post-channel
            final_args = final_mlp_args(args)
            self.decode = MLPNetwork(final_args)
            layers.append(self.decode)
        else:
            # need a network to go from the embed_dim to the object_dim
            if (not self.no_decode) and self.embed_dim > 0:
                final_args = final_conv_args(args)
                self.decode = ConvNetwork(final_args)
                layers.append(self.decode)

        # plain list of submodules; presumably consumed by the Network base class (e.g. reset_network_parameters) -- TODO confirm
        self.model = layers
        self.train()
        self.reset_network_parameters()
    
    def forward(self, key, query, mask, ret_settings):
        """Run the pair network.

        Args:
            key: single key, assumed [batch, 1, key_dim] (as merge_key_queries expects).
            query: assumed [batch, n_queries, query_dim].
            mask: assumed [batch, 1, n_queries], or None.
            ret_settings: forwarded unchanged to ``return_values``.

        Returns:
            ``return_values(ret_settings, x, (key, query), embeddings, reduction)``, where:
            - ``embeddings`` is the last conv output, before aggregation/decoding;
            - ``reduction`` is the flattened reduced vector when ``aggregate_final``,
              otherwise None.
        """
        # assumes only a single key, see keypair for multi-key networks
        x = merge_key_queries(key, query, mask, append_keys=self.append_keys, append_mask=self.append_mask, append_broadcast_mask=self.append_broadcast_mask) # [batch, embed_dim * 2, n_queries]
        for i in range(self.num_layers):
            # repeat_layers reuses conv_layers[0]; the last step uses the separately built
            # final layer (index -1) when last_dim != embed_dim
            layer_idx = (-1 if self.last_dim != self.embed_dim and i == self.num_layers-1 else 0) if self.repeat_layers else i
            # print(i, layer_idx, x.shape, self.num_layers - 1, key.shape, query.shape, mask.shape)
            x = self.conv_layers[layer_idx](x)
            if i < self.num_layers - 1:
                # print(x.shape, self.reduce_function, self.conv_layers)
                # collapse the conv output with the configured reduce function (presumably across queries)
                x = reduce_function(self.reduce_function, x) # TODO: no concatenation support
                # print(x.shape, query.shape, mask.shape)
                # append_keys should be true unless there is no point in running this
                # The reduction becomes the new single key. If append_keys is False,
                # merge_key_queries ignores it and the next layer sees only the (masked) queries.
                x = merge_key_queries(x.unsqueeze(1), query, mask, append_broadcast_mask=self.append_broadcast_mask, append_mask=self.append_mask, append_keys=self.append_keys)
                # print(x.shape, self.no_decode)
            # print(x.shape)
        embeddings, reduction = x, None
        if self.aggregate_final:
            # combine the conv outputs using the reduce function, and append any post channels
            x = reduce_function(self.reduce_function, x)
            x = x.view(x.shape[0], -1)
            reduction = x
            # final network goes to batch, num_outputs
            if self.embed_dim > 0 and not self.no_decode: x = self.decode(x)
        else:
            # when dealing with num_query outputs
            if self.embed_dim > 0 and not self.no_decode: x = self.decode(x)
            # assuming x is [batch, channels, n_queries]: flatten query-major to [batch, n_queries * channels]
            x = x.transpose(2,1)
            x = x.reshape(x.shape[0], -1)
        return return_values(ret_settings, x, (key,query), embeddings, reduction)